import os
import cv2
import numpy as np
# import plotly


# from plotly.graph_objs import Scatter
# from plotly.graph_objs.scatter import Line
#
#
# # Plots min, max and mean + standard deviation bars of a population over time
# def lineplot(xs, ys_population, title, path='', xaxis='episode'):
#     max_colour, mean_colour, std_colour, transparent = 'rgb(0, 132, 180)', 'rgb(0, 172, 237)', 'rgba(29, 202, 255, 0.2)', 'rgba(0, 0, 0, 0)'
#
#     if isinstance(ys_population[0], list) or isinstance(ys_population[0], tuple):
#         ys = np.asarray(ys_population, dtype=np.float32)
#         ys_min, ys_max, ys_mean, ys_std, ys_median = ys.min(1), ys.max(1), ys.mean(1), ys.std(1), np.median(ys, 1)
#         ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std
#
#         trace_max = Scatter(x=xs, y=ys_max, line=Line(color=max_colour, dash='dash'), name='Max')
#         trace_upper = Scatter(x=xs, y=ys_upper, line=Line(color=transparent), name='+1 Std. Dev.', showlegend=False)
#         trace_mean = Scatter(x=xs, y=ys_mean, fill='tonexty', fillcolor=std_colour, line=Line(color=mean_colour),
#                              name='Mean')
#         trace_lower = Scatter(x=xs, y=ys_lower, fill='tonexty', fillcolor=std_colour, line=Line(color=transparent),
#                               name='-1 Std. Dev.', showlegend=False)
#         trace_min = Scatter(x=xs, y=ys_min, line=Line(color=max_colour, dash='dash'), name='Min')
#         trace_median = Scatter(x=xs, y=ys_median, line=Line(color=max_colour), name='Median')
#         data = [trace_upper, trace_mean, trace_lower, trace_min, trace_max, trace_median]
#     else:
#         data = [Scatter(x=xs, y=ys_population, line=Line(color=mean_colour))]
#     plotly.offline.plot({
#         'data': data,
#         'layout': dict(title=title, xaxis={'title': xaxis}, yaxis={'title': title})
#     }, filename=os.path.join(path, title + '.html'), auto_open=False)


def write_video(frames, title, path=''):
    frames = np.multiply(np.stack(frames, axis=0).transpose(0, 2, 3, 1), 255).clip(0, 255).astype(np.uint8)[:, :, :,
             ::-1]  # VideoWrite expects H x W x C in BGR
    _, H, W, _ = frames.shape
    writer = cv2.VideoWriter(os.path.join(path, '%s.mp4' % title), cv2.VideoWriter_fourcc(*'mp4v'), 30., (W, H), True)
    for frame in frames:
        writer.write(frame)
    writer.release()


def transform_info(all_infos):
    """Regroup per-step info dicts into one float32 array per info key.

    Input: ``all_infos`` is a nested list indexed as
        [episode][time]{info_key: info_value}
    Output: a dictionary indexed as [info_key][episode][time]. Each value is a
        float32 array of shape (num_episode, T).

    The set of keys is taken from the first step of the first episode. Every
    episode must have the same number of steps T.

    Returns:
        An empty dict when there are no episodes or the episodes have no steps.

    Raises:
        ValueError: if episodes have differing numbers of steps.
        KeyError: if a step lacks a key present in the first step.
    """
    if len(all_infos) == 0:
        return {}
    T = len(all_infos[0])
    # Validate up front. Otherwise a length-1 episode would silently broadcast
    # across a whole row, and other mismatches give an opaque broadcast error.
    for episode_idx, episode in enumerate(all_infos):
        if len(episode) != T:
            raise ValueError(
                'all episodes must have the same length: episode 0 has %d steps, '
                'episode %d has %d' % (T, episode_idx, len(episode)))
    if T == 0:
        return {}
    return {
        info_name: np.array(
            [[info[info_name] for info in episode] for episode in all_infos],
            dtype=np.float32,
        )
        for info_name in all_infos[0][0]
    }
